// --------------------------------------------------------------------------------------------------------------------
// <copyright file="ExpressionExpander.cs" company="Open Trader">
//   Copyright (c) David Denis (david.denis@systemathics.com)
// </copyright>
// <summary>
//   |  Open Trader - The Open Source Systematic Trading Platform
//   |
//   |  This program is free software: you can redistribute it and/or modify
//   |  it under the terms of the GNU General Public License as published by
//   |  the Free Software Foundation, either version 2 of the License, or
//   |  (at your option) any later version.
//   |
//   |  This program is distributed in the hope that it will be useful,
//   |  but WITHOUT ANY WARRANTY; without even the implied warranty of
//   |  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//   |  GNU General Public License for more details.
//   |
//   |  You should have received a copy of the GNU General Public License
//   |  along with this program.  If not, see http://www.gnu.org/licenses
//   |
//   |  Up to date informations about Open Trader can be found at :
//   |    http://opentrader.org
//   |    http://opentrader.codeplex.com
//   |
//   |  For professional services, please visit us at :
//   |    http://www.systemathics.com
// </summary>
// --------------------------------------------------------------------------------------------------------------------

namespace Org.OpenTrader.Framework.LinqKit
{
    #region Using Directives

    using System;
    using System.Collections.Generic;
    using System.Linq.Expressions;
    using System.Reflection;

    #endregion

    /// <summary>
    /// Custom expression visitor for ExpandableQuery. This expands calls to Expression.Compile() and
    /// collapses captured lambda references in subqueries which LINQ to SQL can't otherwise handle.
    /// </summary>
    internal class ExpressionExpander : ExpressionVisitor
    {
        #region Constants and Fields

        /// <summary>
        /// Message of the exception thrown when a lambda parameter already has a replacement registered.
        /// </summary>
        private const string RecursiveInvokeMessage =
            "Invoke cannot be called recursively - try using a temporary variable.";

        /// <summary>
        /// Replacement expressions for lambda parameters, used while the body of an invoked lambda is
        /// being inlined. Left null when the visitor is created through the parameterless constructor.
        /// </summary>
        private readonly Dictionary<ParameterExpression, Expression> replaceVars;

        #endregion

        #region Constructors and Destructors

        /// <summary>
        /// Initializes a new instance of the <see cref="ExpressionExpander"/> class without any
        /// parameter replacements.
        /// </summary>
        internal ExpressionExpander()
        {
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="ExpressionExpander"/> class that substitutes
        /// the given parameters while visiting.
        /// </summary>
        /// <param name="replaceVars">
        /// Maps each lambda parameter to the expression that must take its place.
        /// </param>
        private ExpressionExpander(Dictionary<ParameterExpression, Expression> replaceVars)
        {
            this.replaceVars = replaceVars;
        }

        #endregion

        #region Methods

        /// <summary>
        /// Flatten calls to Invoke so that Entity Framework can understand it. Calls to Invoke are generated
        /// by PredicateBuilder.
        /// </summary>
        /// <param name="iv">
        /// The invocation to flatten.
        /// </param>
        /// <returns>
        /// The body of the invoked lambda with its parameters replaced by the (expanded) invocation
        /// arguments, or the result of the base visitor when the invocation target cannot be resolved
        /// to a lambda expression.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// A parameter of the invoked lambda already has a replacement registered.
        /// </exception>
        protected override Expression VisitInvocation(InvocationExpression iv)
        {
            var lambda = this.ResolveLambda(iv.Expression);
            if (lambda == null)
            {
                // Not an expression tree we can inline (e.g. a plain delegate): leave the node to the base visitor.
                return base.VisitInvocation(iv);
            }

            return this.ExpandLambda(lambda, iv.Arguments, 0);
        }

        /// <summary>
        /// Collapses member accesses on compiler-generated types (names starting with "&lt;&gt;") that hold
        /// a captured expression; every other member access is visited normally.
        /// </summary>
        /// <param name="m">
        /// The member access to visit.
        /// </param>
        /// <returns>
        /// The visited captured expression when one was collapsed; otherwise the result of the base visitor.
        /// </returns>
        protected override Expression VisitMemberAccess(MemberExpression m)
        {
            // Strip out any references to expressions captured by outer variables - LINQ to SQL can't handle these:
            var declaringType = m.Member.DeclaringType;
            if (declaringType != null && declaringType.Name.StartsWith("<>", StringComparison.Ordinal))
            {
                var collapsed = this.TransformExpr(m);
                if (collapsed != m)
                {
                    return collapsed;
                }

                // Nothing was collapsed (for instance a property of an anonymous type, whose name also
                // starts with "<>"): fall through so the children still get visited and parameters replaced.
            }

            return base.VisitMemberAccess(m);
        }

        /// <summary>
        /// Expands calls to Extensions.Invoke and to Compile() on captured expressions, and strips nested
        /// calls to Extensions.AsExpandable.
        /// </summary>
        /// <param name="m">
        /// The method call to visit.
        /// </param>
        /// <returns>
        /// The expanded expression, or the result of the base visitor when no expansion applies.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// A parameter of the invoked lambda already has a replacement registered.
        /// </exception>
        protected override Expression VisitMethodCall(MethodCallExpression m)
        {
            var declaredOnExtensions = m.Method.DeclaringType == typeof(Extensions);

            if (declaredOnExtensions && m.Method.Name == "Invoke")
            {
                // Argument 0 is the expression being invoked; the remaining arguments feed its parameters.
                var lambda = this.ResolveLambda(m.Arguments[0]);
                if (lambda != null)
                {
                    return this.ExpandLambda(lambda, m.Arguments, 1);
                }
            }

            // Expand calls to an expression's Compile() method:
            var compileTarget = m.Object as MemberExpression;
            if (compileTarget != null && m.Method.Name == "Compile")
            {
                var newExpr = this.TransformExpr(compileTarget);
                if (newExpr != compileTarget)
                {
                    return newExpr;
                }
            }

            // Strip out any nested calls to AsExpandable():
            if (declaredOnExtensions && m.Method.Name == "AsExpandable")
            {
                return m.Arguments[0];
            }

            return base.VisitMethodCall(m);
        }

        /// <summary>
        /// Replaces a parameter by its registered replacement, if any.
        /// </summary>
        /// <param name="p">
        /// The parameter to visit.
        /// </param>
        /// <returns>
        /// The registered replacement, or the result of the base visitor when there is none.
        /// </returns>
        protected override Expression VisitParameter(ParameterExpression p)
        {
            Expression replacement;
            if (this.replaceVars != null && this.replaceVars.TryGetValue(p, out replacement))
            {
                return replacement;
            }

            return base.VisitParameter(p);
        }

        /// <summary>
        /// Tells whether a type is a private nested type whose name starts with "&lt;&gt;", which is how
        /// this visitor recognizes the classes holding captured outer variables.
        /// </summary>
        /// <param name="type">
        /// The type to test; may be null.
        /// </param>
        /// <returns>
        /// True when the type is non-null, nested private and has a name starting with "&lt;&gt;".
        /// </returns>
        private static bool IsClosureType(Type type)
        {
            return type != null && type.IsNestedPrivate && type.Name.StartsWith("<>", StringComparison.Ordinal);
        }

        /// <summary>
        /// Resolves the target of an invocation to the lambda expression it designates.
        /// </summary>
        /// <param name="target">
        /// The invoked expression: a lambda, a constant holding one, or a captured variable holding one.
        /// </param>
        /// <returns>
        /// The lambda expression, or null when the target does not resolve to one.
        /// </returns>
        private LambdaExpression ResolveLambda(Expression target)
        {
            var member = target as MemberExpression;
            if (member != null)
            {
                target = this.TransformExpr(member);
            }

            var constant = target as ConstantExpression;
            if (constant != null)
            {
                target = constant.Value as Expression;
            }

            return target as LambdaExpression;
        }

        /// <summary>
        /// Inlines a lambda: visits its body with every parameter replaced by the matching argument.
        /// </summary>
        /// <param name="lambda">
        /// The lambda whose body is inlined.
        /// </param>
        /// <param name="arguments">
        /// The argument list of the invoking node.
        /// </param>
        /// <param name="argumentOffset">
        /// Index in <paramref name="arguments"/> of the argument bound to the first lambda parameter.
        /// </param>
        /// <returns>
        /// The visited lambda body.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// A parameter of the lambda already has a replacement registered.
        /// </exception>
        private Expression ExpandLambda(LambdaExpression lambda, IList<Expression> arguments, int argumentOffset)
        {
            var parameterCount = lambda.Parameters.Count;

            // Visit the arguments with the current visitor first, so that nested Invoke calls get expanded
            // and parameters of an enclosing inlined lambda get replaced before they are substituted.
            // This is done outside the try block so unrelated ArgumentExceptions are not mislabelled.
            var visitedArguments = new Expression[parameterCount];
            for (var i = 0; i < parameterCount; i++)
            {
                visitedArguments[i] = this.Visit(arguments[i + argumentOffset]);
            }

            var substitutions = this.replaceVars == null
                                    ? new Dictionary<ParameterExpression, Expression>()
                                    : new Dictionary<ParameterExpression, Expression>(this.replaceVars);

            try
            {
                for (var i = 0; i < parameterCount; i++)
                {
                    substitutions.Add(lambda.Parameters[i], visitedArguments[i]);
                }
            }
            catch (ArgumentException ex)
            {
                // Dictionary.Add rejects a key that is already present, i.e. the lambda is being expanded
                // inside its own body.
                throw new InvalidOperationException(RecursiveInvokeMessage, ex);
            }

            return new ExpressionExpander(substitutions).Visit(lambda.Body);
        }

        /// <summary>
        /// Collapses a reference to a captured outer variable that holds an expression.
        /// </summary>
        /// <param name="input">
        /// The member access to collapse.
        /// </param>
        /// <returns>
        /// The visited captured expression when <paramref name="input"/> reads an expression-valued field
        /// of a constant closure object; otherwise <paramref name="input"/> itself.
        /// </returns>
        private Expression TransformExpr(MemberExpression input)
        {
            if (input == null)
            {
                return input;
            }

            // Only fields declared on closure types are captured outer variables.
            var field = input.Member as FieldInfo;
            if (field == null || !IsClosureType(field.ReflectedType))
            {
                return input;
            }

            // The closure instance must be available as a constant to read the captured value from it.
            var closure = input.Expression as ConstantExpression;
            if (closure == null || closure.Value == null || !IsClosureType(closure.Value.GetType()))
            {
                return input;
            }

            var captured = field.GetValue(closure.Value) as Expression;
            if (captured == null)
            {
                return input;
            }

            return this.Visit(captured);
        }

        #endregion
    }
}